{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transformer 结构\n",
    "\n",
    "# 1. FFN层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class PositionwiseFeedforwardLayer(nn.Module):\n",
    "    def __init__(self, hid_dim, pf_dim, dropout):\n",
    "        super().__init__() \n",
    "        self.fc_1 = nn.Linear(hid_dim, pf_dim)\n",
    "        self.fc_2 = nn.Linear(pf_dim, hid_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, x): \n",
    "        #x = [batch size, seq len, hid dim]\n",
    "        x = self.dropout(torch.relu(self.fc_1(x))) \n",
    "        #x = [batch size, seq len, pf dim] \n",
    "        x = self.fc_2(x)\n",
    "        #x = [batch size, seq len, hid dim]\n",
    "        return x"
   ]
  },
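  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check for the feedforward layer (a minimal sketch; `hid_dim=512`, `pf_dim=2048`, and the batch/sequence sizes below are illustrative assumptions, not values fixed elsewhere in this notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative smoke test: all dimensions here are assumptions, not prescribed values\n",
    "ffn = PositionwiseFeedforwardLayer(hid_dim=512, pf_dim=2048, dropout=0.1)\n",
    "x = torch.randn(2, 10, 512)   # [batch size, seq len, hid dim]\n",
    "out = ffn(x)\n",
    "print(out.shape)              # torch.Size([2, 10, 512]): hid dim is preserved"
   ]
  },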
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 多头注意力机制"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadAttentionLayer(nn.Module):\n",
    "    def __init__(self, hid_dim, n_heads, dropout, device):\n",
    "        super().__init__()\n",
    "        assert hid_dim % n_heads == 0\n",
    "        self.hid_dim = hid_dim\n",
    "        self.n_heads = n_heads\n",
    "        self.head_dim = hid_dim // n_heads\n",
    "        self.fc_q = nn.Linear(hid_dim, hid_dim)\n",
    "        self.fc_k = nn.Linear(hid_dim, hid_dim)\n",
    "        self.fc_v = nn.Linear(hid_dim, hid_dim)\n",
    "        self.fc_o = nn.Linear(hid_dim, hid_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)\n",
    "    def forward(self, query, key, value, mask = None):\n",
    "        batch_size = query.shape[0]\n",
    "        #query = [batch size, query len, hid dim]\n",
    "        #key = [batch size, key len, hid dim]\n",
    "        #value = [batch size, value len, hid dim]\n",
    "        Q = self.fc_q(query)\n",
    "        K = self.fc_k(key)\n",
    "        V = self.fc_v(value)\n",
    "        #Q = [batch size, query len, hid dim]\n",
    "        #K = [batch size, key len, hid dim]\n",
    "        #V = [batch size, value len, hid dim]\n",
    "        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        #Q = [batch size, n heads, query len, head dim]\n",
    "        #K = [batch size, n heads, key len, head dim]\n",
    "        #V = [batch size, n heads, value len, head dim]\n",
    "        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale\n",
    "        #energy = [batch size, n heads, query len, key len]\n",
    "        if mask is not None:\n",
    "            energy = energy.masked_fill(mask == 0, -1e10)\n",
    "        attention = torch.softmax(energy, dim = -1)\n",
    "        #attention = [batch size, n heads, query len, key len]\n",
    "        x = torch.matmul(self.dropout(attention), V)\n",
    "        #x = [batch size, n heads, query len, head dim]\n",
    "        x = x.permute(0, 2, 1, 3).contiguous()\n",
    "        #x = [batch size, query len, n heads, head dim]\n",
    "        x = x.view(batch_size, -1, self.hid_dim)\n",
    "        #x = [batch size, query len, hid dim]\n",
    "        x = self.fc_o(x)\n",
    "        #x = [batch size, query len, hid dim]\n",
    "        return x, attention "
   ]
  },
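  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sanity check for the attention layer (a minimal sketch; `hid_dim=512`, `n_heads=8`, and the CPU device are assumptions). Self-attention is the case where query, key, and value are the same tensor, and any mask must broadcast against `[batch size, n heads, query len, key len]`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check; hid_dim, n_heads, and the CPU device are assumptions\n",
    "device = torch.device('cpu')\n",
    "mha = MultiHeadAttentionLayer(hid_dim=512, n_heads=8, dropout=0.1, device=device)\n",
    "x = torch.randn(2, 10, 512)            # [batch size, seq len, hid dim]\n",
    "mask = torch.ones(2, 1, 1, 10)         # broadcasts over heads and query positions\n",
    "out, attention = mha(x, x, x, mask)    # self-attention: query = key = value\n",
    "print(out.shape)                       # torch.Size([2, 10, 512])\n",
    "print(attention.shape)                 # torch.Size([2, 8, 10, 10])"
   ]
  },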
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Encoder层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderLayer(nn.Module):\n",
    "    def __init__(self, \n",
    "                 hid_dim, \n",
    "                 n_heads, \n",
    "                 pf_dim,  \n",
    "                 dropout, \n",
    "                 device):\n",
    "        super().__init__()        \n",
    "        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)\n",
    "        self.ff_layer_norm = nn.LayerNorm(hid_dim)\n",
    "        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n",
    "        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, \n",
    "                                                                     pf_dim, \n",
    "                                                                     dropout)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, src, src_mask):        \n",
    "        #src = [batch size, src len, hid dim]\n",
    "        #src_mask = [batch size, src len]                \n",
    "        #self attention\n",
    "        _src, _ = self.self_attention(src, src, src, src_mask)        \n",
    "        #dropout, residual connection and layer norm\n",
    "        src = self.self_attn_layer_norm(src + self.dropout(_src))        \n",
    "        #src = [batch size, src len, hid dim]        \n",
    "        #positionwise feedforward\n",
    "        _src = self.positionwise_feedforward(src)        \n",
    "        #dropout, residual and layer norm\n",
    "        src = self.ff_layer_norm(src + self.dropout(_src))        \n",
    "        #src = [batch size, src len, hid dim]        \n",
    "        return src "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, \n",
    "                 input_dim, \n",
    "                 hid_dim, \n",
    "                 n_layers, \n",
    "                 n_heads, \n",
    "                 pf_dim,\n",
    "                 dropout, \n",
    "                 device,\n",
    "                 max_length = 100):\n",
    "        super().__init__()\n",
    "\n",
    "        self.device = device\n",
    "        self.tok_embedding = nn.Embedding(input_dim, hid_dim)\n",
    "        self.pos_embedding = nn.Embedding(max_length, hid_dim)\n",
    "        self.layers = nn.ModuleList([EncoderLayer(hid_dim, \n",
    "                                                  n_heads, \n",
    "                                                  pf_dim,\n",
    "                                                  dropout, \n",
    "                                                  device) \n",
    "                                     for _ in range(n_layers)])\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)\n",
    "    def forward(self, src, src_mask):\n",
    "        #src = [batch size, src len]\n",
    "        #src_mask = [batch size, src len]\n",
    "        batch_size = src.shape[0]\n",
    "        src_len = src.shape[1]\n",
    "        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)\n",
    "        #pos = [batch size, src len]\n",
    "        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))\n",
    "        #src = [batch size, src len, hid dim]\n",
    "        for layer in self.layers:\n",
    "            src = layer(src, src_mask)\n",
    "        #src = [batch size, src len, hid dim]\n",
    "        return src"
   ]
  },
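  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Running the encoder on a toy batch of token ids (a minimal sketch; the vocabulary size, model dimensions, and pad index 0 are assumptions). The padding mask is built by hand here with the same `[batch size, 1, 1, src len]` shape that `make_src_mask` produces in section 7:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative example; all sizes and the pad index 0 below are assumptions\n",
    "device = torch.device('cpu')\n",
    "enc = Encoder(input_dim=1000, hid_dim=512, n_layers=2, n_heads=8,\n",
    "              pf_dim=2048, dropout=0.1, device=device)\n",
    "src = torch.randint(1, 1000, (2, 7))             # [batch size, src len] token ids\n",
    "src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # [batch size, 1, 1, src len]\n",
    "enc_src = enc(src, src_mask)\n",
    "print(enc_src.shape)                             # torch.Size([2, 7, 512])"
   ]
  },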
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Decoder层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderLayer(nn.Module):\n",
    "    def __init__(self, \n",
    "                 hid_dim, \n",
    "                 n_heads, \n",
    "                 pf_dim, \n",
    "                 dropout, \n",
    "                 device):\n",
    "        super().__init__()\n",
    "        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)\n",
    "        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)\n",
    "        self.ff_layer_norm = nn.LayerNorm(hid_dim)\n",
    "        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n",
    "        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)\n",
    "        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, \n",
    "                                                                     pf_dim, \n",
    "                                                                     dropout)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "    def forward(self, trg, enc_src, trg_mask, src_mask):\n",
    "        #trg = [batch size, trg len, hid dim]\n",
    "        #enc_src = [batch size, src len, hid dim]\n",
    "        #trg_mask = [batch size, trg len]\n",
    "        #src_mask = [batch size, src len]\n",
    "        #self attention\n",
    "        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)\n",
    "        #dropout, residual connection and layer norm\n",
    "        trg = self.self_attn_layer_norm(trg + self.dropout(_trg))\n",
    "        #trg = [batch size, trg len, hid dim]\n",
    "        #encoder attention\n",
    "        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)\n",
    "        #dropout, residual connection and layer norm\n",
    "        trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))\n",
    "        #trg = [batch size, trg len, hid dim]\n",
    "        #positionwise feedforward\n",
    "        _trg = self.positionwise_feedforward(trg)\n",
    "        #dropout, residual and layer norm\n",
    "        trg = self.ff_layer_norm(trg + self.dropout(_trg))\n",
    "        #trg = [batch size, trg len, hid dim]\n",
    "        #attention = [batch size, n heads, trg len, src len]\n",
    "        return trg, attention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, \n",
    "                 output_dim, \n",
    "                 hid_dim, \n",
    "                 n_layers, \n",
    "                 n_heads, \n",
    "                 pf_dim, \n",
    "                 dropout, \n",
    "                 device,\n",
    "                 max_length = 100):\n",
    "        super().__init__()\n",
    "        self.device = device\n",
    "        self.tok_embedding = nn.Embedding(output_dim, hid_dim)\n",
    "        self.pos_embedding = nn.Embedding(max_length, hid_dim)\n",
    "        self.layers = nn.ModuleList([DecoderLayer(hid_dim, \n",
    "                                                  n_heads, \n",
    "                                                  pf_dim, \n",
    "                                                  dropout, \n",
    "                                                  device)\n",
    "                                     for _ in range(n_layers)])\n",
    "        self.fc_out = nn.Linear(hid_dim, output_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)\n",
    "    def forward(self, trg, enc_src, trg_mask, src_mask):\n",
    "        #trg = [batch size, trg len]\n",
    "        #enc_src = [batch size, src len, hid dim]\n",
    "        #trg_mask = [batch size, trg len]\n",
    "        #src_mask = [batch size, src len]\n",
    "        batch_size = trg.shape[0]\n",
    "        trg_len = trg.shape[1]\n",
    "        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)\n",
    "        #pos = [batch size, trg len]\n",
    "        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))\n",
    "        #trg = [batch size, trg len, hid dim]\n",
    "        for layer in self.layers:\n",
    "            trg, attention = layer(trg, enc_src, trg_mask, src_mask)\n",
    "        #trg = [batch size, trg len, hid dim]\n",
    "        #attention = [batch size, n heads, trg len, src len]\n",
    "        output = self.fc_out(trg)\n",
    "        #output = [batch size, trg len, output dim]\n",
    "        return output, attention "
   ]
  },
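  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Exercising the decoder on its own (a minimal sketch; the sizes are assumptions and `enc_src` is random noise rather than real encoder output). The target mask combines a padding mask with a lower-triangular (causal) mask, mirroring what `make_trg_mask` does in section 7:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative example; dimensions, pad index 0, and the random enc_src are assumptions\n",
    "device = torch.device('cpu')\n",
    "dec = Decoder(output_dim=1000, hid_dim=512, n_layers=2, n_heads=8,\n",
    "              pf_dim=2048, dropout=0.1, device=device)\n",
    "trg = torch.randint(1, 1000, (2, 5))                 # [batch size, trg len]\n",
    "enc_src = torch.randn(2, 7, 512)                     # [batch size, src len, hid dim]\n",
    "src_mask = torch.ones(2, 1, 1, 7).bool()             # no source padding in this toy batch\n",
    "trg_pad_mask = (trg != 0).unsqueeze(1).unsqueeze(2)  # [batch size, 1, 1, trg len]\n",
    "trg_sub_mask = torch.tril(torch.ones(5, 5)).bool()   # causal lower-triangular mask\n",
    "trg_mask = trg_pad_mask & trg_sub_mask               # [batch size, 1, trg len, trg len]\n",
    "output, attention = dec(trg, enc_src, trg_mask, src_mask)\n",
    "print(output.shape)                                  # torch.Size([2, 5, 1000])\n",
    "print(attention.shape)                               # torch.Size([2, 8, 5, 7])"
   ]
  },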
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Transformer结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Transformer(nn.Module):\n",
    "    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device) -> None:\n",
    "        super().__init__()\n",
    "        self.encoder     = encoder\n",
    "        self.decoder     = decoder\n",
    "        self.src_pad_idx = src_pad_idx\n",
    "        self.trg_pad_idx = trg_pad_idx\n",
    "        self.device      = device\n",
    "        \n",
    "    def make_src_mask(self, src):\n",
    "        #! src shape = [batch_size, src_seq_len]\n",
    "        # 填充的地方为0， 非填充的地方为1\n",
    "        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)\n",
    "        return src_mask # shape = [batch_size, 1, 1, src_seq_len]\n",
    "    \n",
    "    def make_trg_mask(self, trg):\n",
    "        #  [batch_size, 1, 1 trg_seq_len]\n",
    "        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)\n",
    "        trg_len      = trg.shape[1]\n",
    "        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()  #! 下三角矩阵\n",
    "        return trg_pad_mask & trg_sub_mask    # shape = [batch_size, 1, trg_seq_len, trg_seq_len]\n",
    "        \n",
    "        \n",
    "        \n",
    "    def forward(self, src, trg):\n",
    "        src_mask = self.make_src_mask(src)          # [batch_size, 1, 1, src_seq_len]\n",
    "        trg_mask = self.make_trg_mask(trg)          # [batch_size, 1, trg_seq_len, trg_seq_len]\n",
    "        enc_src  = self.encoder(src)                # [batch_size, src_seq_len, hidden_dim]\n",
    "        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)\n",
    "        #output    = [batch_size, trg_seq_len, output_dim]\n",
    "        #attention = [batch_size, n_heads, trg_seq_len, src_seq_len]\n",
    "        return output, attention"
   ]
  },
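  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the pieces together end to end (a minimal sketch; every hyperparameter and the pad index 0 below are assumptions chosen only to exercise the shapes):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative end-to-end run; all hyperparameters and pad indices are assumptions\n",
    "device = torch.device('cpu')\n",
    "INPUT_DIM, OUTPUT_DIM, HID_DIM = 1000, 1200, 256\n",
    "enc = Encoder(INPUT_DIM, HID_DIM, n_layers=2, n_heads=8, pf_dim=512,\n",
    "              dropout=0.1, device=device)\n",
    "dec = Decoder(OUTPUT_DIM, HID_DIM, n_layers=2, n_heads=8, pf_dim=512,\n",
    "              dropout=0.1, device=device)\n",
    "model = Transformer(enc, dec, src_pad_idx=0, trg_pad_idx=0, device=device)\n",
    "\n",
    "src = torch.randint(1, INPUT_DIM, (2, 7))    # [batch size, src len]\n",
    "trg = torch.randint(1, OUTPUT_DIM, (2, 5))   # [batch size, trg len]\n",
    "output, attention = model(src, trg)\n",
    "print(output.shape)                          # torch.Size([2, 5, 1200])\n",
    "print(attention.shape)                       # torch.Size([2, 8, 5, 7])"
   ]
  }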
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
